#!/usr/bin/env python3
"""
Phase 4c: Three-way interactions (Z × KAOPEN × income_group)
Phase 4d: Pension model on EBA-49 subsample

Tests whether Z×KAOPEN interactions vary by income level, and whether
pension-demographics interactions are a more robust AE-specific finding.
"""

import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Hard-coded project roots: the follow-up analysis directory and its parent.
FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
PROJECT_DIR = FOLLOWUP_DIR.parent

# Put both roots at the front of the module search path (follow-up directory
# first, project directory second) before the `src` imports below run.
sys.path[:0] = [str(FOLLOWUP_DIR), str(PROJECT_DIR)]

from src.model import PanelGLS, estimate_baseline_model
from src.macro import EBA_COUNTRIES, SSA_COUNTRIES, EU_EXPANSION, EXPANSION_TIER1, filter_eba_sample

# Directory that receives the CSV result tables written by this script.
OUTPUT_DIR = FOLLOWUP_DIR.joinpath("output", "tables")

# ─── Load data ───────────────────────────────────────────────────────────
panel_path = FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv"
panel = pd.read_csv(panel_path)

# Estimation sample: rows with an observed current account (ca_gdp) whose
# year lies in the closed interval 1986-2024.
has_ca = panel['ca_gdp'].notna()
in_window = panel['year'].between(1986, 2024)
est = panel.loc[has_ca & in_window].copy()

# Full 140-country sample (extended and expansion countries both enabled).
full_sample = filter_eba_sample(est, extended=True, expansion=True)
print(f"\nFull sample: {full_sample['iso3'].nunique()} countries, {len(full_sample):,} obs")

# ═══════════════════════════════════════════════════════════════════════════
# PHASE 4c: THREE-WAY INTERACTIONS
# ═══════════════════════════════════════════════════════════════════════════
banner_4c = "=" * 70
print("\n" + banner_4c)
print("PHASE 4c: THREE-WAY INTERACTIONS (Z × KAOPEN × INCOME GROUP)")
print(banner_4c)

# --- Income group classification ---
# Countries are ranked on the per-country median of an income measure taken
# from the panel itself: gdp_pc_ppp when present, otherwise a proxy.
sample_columns = full_sample.columns
if 'gdp_pc_ppp' not in sample_columns:
    print("gdp_pc_ppp not available, computing from WEO data...")
    if 'ngdp_usd' in sample_columns and 'population_weo' in sample_columns:
        # Proxy 1: ngdp_usd divided by population_weo.
        full_sample['gdp_pc_approx'] = full_sample['ngdp_usd'] / full_sample['population_weo']
        gdp_pc = full_sample.groupby('iso3')['gdp_pc_approx'].median()
    else:
        # Proxy 2: output per worker.
        gdp_pc = full_sample.groupby('iso3')['output_per_worker'].median()
    print(f"Proxy GDP per capita for {gdp_pc.notna().sum()} countries")
else:
    gdp_pc = full_sample.groupby('iso3')['gdp_pc_ppp'].median()
    print(f"\nGDP per capita PPP available for {gdp_pc.notna().sum()} countries")
    print(f"  Range: ${gdp_pc.min():,.0f} - ${gdp_pc.max():,.0f}")
    print(f"  Median: ${gdp_pc.median():,.0f}")

# Split countries into three income groups using the 33rd and 67th
# percentiles of the sample's own per-country income distribution.
q33 = gdp_pc.quantile(0.33)
q67 = gdp_pc.quantile(0.67)


def _income_label(value):
    """Return 'high', 'middle' or 'low' for one country's income value.

    Missing values fall back to 'middle'.
    """
    if pd.isna(value):
        return 'middle'  # default
    if value >= q67:
        return 'high'
    if value >= q33:
        return 'middle'
    return 'low'


income_map = {iso: _income_label(val) for iso, val in gdp_pc.items()}
full_sample['income_group'] = full_sample['iso3'].map(income_map)

# Report group composition (at most the first 15 codes, alphabetically).
for grp in ('high', 'middle', 'low'):
    members = full_sample.loc[full_sample['income_group'] == grp, 'iso3'].unique()
    print(f"\n  {grp.upper()} income ({len(members)} countries):")
    shown = ', '.join(sorted(members)[:15])
    ellipsis = '...' if len(members) > 15 else ''
    print(f"    {shown}{ellipsis}")

# --- Build three-way interaction variables ---
# Each Z×KAOPEN term is multiplied by a 0/1 income-group dummy, giving
# Z × KAOPEN × I(income=high), Z × KAOPEN × I(income=middle), and so on.
income_levels = ('high', 'middle', 'low')
z_factors = ('Z_1', 'Z_2', 'Z_3')

for level in income_levels:
    full_sample[f'is_{level}'] = full_sample['income_group'].eq(level).astype(float)

# Three-way interactions: Z_i × KAOPEN × income_dummy.  A factor is skipped
# when either kaopen or the Z column itself is missing from the frame.
kaopen_present = 'kaopen' in full_sample.columns
for z in z_factors:
    if not (kaopen_present and z in full_sample.columns):
        continue
    for level in income_levels:
        full_sample[f'{z}_x_kaopen_x_{level}'] = (
            full_sample[z] * full_sample['kaopen'] * full_sample[f'is_{level}']
        )

# Z × income interactions (without KAOPEN), used as controls.
for z in z_factors:
    for level in income_levels:
        full_sample[f'{z}_x_{level}'] = full_sample[z] * full_sample[f'is_{level}']

# --- Estimate models ---
# Regressor groups shared by the models below.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
control_vars = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth', 'nfa_gdp_lag',
    'log_rel_opw', 'health_exp_gdp',
]
rate_var = ['log_lending_rate']

# Model A: Baseline + Z×KAOPEN (standard, for comparison on 140-country sample)
# Model B: Baseline + Z×KAOPEN×income (three-way)
# Model C: Baseline + Z×income + Z×KAOPEN×income (full three-way with main effects)

# One common estimation sample: rows complete in the dependent variable and
# in every regressor of the extended (two-way interaction) model.
kaopen_terms = [f'{z}_x_kaopen' for z in demo_vars]
ext_vars = [*demo_vars, *control_vars, *rate_var, *kaopen_terms]
comp = full_sample.dropna(subset=[*ext_vars, 'ca_gdp']).copy()
print(f"\nComplete cases for extended model: {len(comp):,} obs, {comp['iso3'].nunique()} countries")

# Coverage of the common sample by income group.
for grp in ('high', 'middle', 'low'):
    group_rows = comp[comp['income_group'] == grp]
    print(f"  {grp}: {group_rows['iso3'].nunique()} countries, {len(group_rows):,} obs")

# Model A: pooled two-way Z×KAOPEN interactions, the reference specification.
print("\n--- Model A: Standard Z×KAOPEN interactions ---")
vars_a = [*demo_vars, *control_vars, *rate_var,
          'Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
y_a = comp['ca_gdp'].values
X_a = comp[vars_a].values

gls_a = PanelGLS()
gls_a.fit(y_a, X_a, comp['iso3'].values, comp['year'].values)
gls_a.feature_names = vars_a

print(f"  R² = {gls_a.r_squared:.4f}, N = {gls_a.n_obs}, rho = {gls_a.rho:.3f}")
# Report only the interaction terms, with conventional significance stars.
for idx, name in enumerate(vars_a):
    if 'x_kaopen' not in name:
        continue
    p_val = gls_a.pvalues[idx]
    if p_val < 0.001:
        stars = '***'
    elif p_val < 0.01:
        stars = '**'
    elif p_val < 0.05:
        stars = '*'
    else:
        stars = ''
    print(f"  {name}: {gls_a.beta[idx]:.3f} (p={p_val:.4f}) {stars}")

# Model B: a separate Z×KAOPEN term for each income group.
print("\n--- Model B: Three-way Z×KAOPEN×income interactions ---")
three_way_vars = [
    f'{z}_x_kaopen_x_{grp}'
    for z in ('Z_1', 'Z_2', 'Z_3')
    for grp in ('high', 'middle', 'low')
]

vars_b = [*demo_vars, *control_vars, *rate_var, *three_way_vars]
y_b = comp['ca_gdp'].values
X_b = comp[vars_b].values

gls_b = PanelGLS()
gls_b.fit(y_b, X_b, comp['iso3'].values, comp['year'].values)
gls_b.feature_names = vars_b

print(f"  R² = {gls_b.r_squared:.4f}, N = {gls_b.n_obs}, rho = {gls_b.rho:.3f}")
print("\n  Three-way interaction coefficients:")
for idx, name in enumerate(vars_b):
    if 'x_kaopen_x_' not in name:
        continue
    p_val = gls_b.pvalues[idx]
    if p_val < 0.001:
        stars = '***'
    elif p_val < 0.01:
        stars = '**'
    elif p_val < 0.05:
        stars = '*'
    else:
        stars = ''
    print(f"  {name}: {gls_b.beta[idx]:.3f} (se={gls_b.se[idx]:.3f}, p={p_val:.4f}) {stars}")

# Model C: Full specification with Z×income main effects + three-way terms.
print("\n--- Model C: Z×income + Z×KAOPEN×income (full) ---")
# Only the high and low Z×income terms enter; middle is the reference group.
income_main_vars = [
    f'{z}_x_{grp}'
    for z in ('Z_1', 'Z_2', 'Z_3')
    for grp in ('high', 'low')
]

vars_c = demo_vars + control_vars + rate_var + income_main_vars + three_way_vars
y_c = comp['ca_gdp'].values
X_c = comp[vars_c].values

gls_c = PanelGLS()
gls_c.fit(y_c, X_c, comp['iso3'].values, comp['year'].values)
gls_c.feature_names = vars_c

print(f"  R² = {gls_c.r_squared:.4f}, N = {gls_c.n_obs}, rho = {gls_c.rho:.3f}")
print("\n  Z×income main effects:")
for i, v in enumerate(vars_c):
    # Select by list membership, not by substring: a test such as
    # '_x_high' in v also matches three-way names like Z_1_x_kaopen_x_high,
    # which would then be listed under the wrong heading and printed twice.
    if v in income_main_vars:
        sig = '***' if gls_c.pvalues[i] < 0.001 else '**' if gls_c.pvalues[i] < 0.01 else '*' if gls_c.pvalues[i] < 0.05 else ''
        print(f"  {v}: {gls_c.beta[i]:.3f} (se={gls_c.se[i]:.3f}, p={gls_c.pvalues[i]:.4f}) {sig}")
print("\n  Three-way interaction coefficients:")
for i, v in enumerate(vars_c):
    if 'x_kaopen_x_' in v:
        sig = '***' if gls_c.pvalues[i] < 0.001 else '**' if gls_c.pvalues[i] < 0.01 else '*' if gls_c.pvalues[i] < 0.05 else ''
        print(f"  {v}: {gls_c.beta[i]:.3f} (se={gls_c.se[i]:.3f}, p={gls_c.pvalues[i]:.4f}) {sig}")

# Joint F-test: do the income-specific three-way terms add anything over the
# pooled two-way Z×KAOPEN terms?  Model A is the restricted model: it equals
# Model B with the high/middle/low coefficients constrained to be equal for
# each Z×KAOPEN term.
#   F = ((R²_u - R²_r) / q) / ((1 - R²_u) / (n - k - 1))
# NOTE(review): both R² values come from separately fitted PanelGLS models; if
# PanelGLS re-estimates rho on every fit the statistic is only approximate —
# confirm against src.model.
q = len(vars_b) - len(vars_a)  # 9 three-way terms - 3 two-way terms = 6 restrictions
n_b = gls_b.n_obs
k_b = len(vars_b)
r2_a = gls_a.r_squared
r2_b = gls_b.r_squared
df_resid = n_b - k_b - 1

F_threeway = ((r2_b - r2_a) / q) / ((1 - r2_b) / df_resid)
# Survival function rather than 1 - cdf: keeps precision for very small p-values.
p_threeway = stats.f.sf(F_threeway, q, df_resid)
print(f"\n  Joint F-test (three-way vs two-way): F({q},{df_resid}) = {F_threeway:.3f}, p = {p_threeway:.4f}")

# Save results: one row per coefficient, three-way model first and the
# two-way reference model second.
results_4c = [
    {
        'model': label,
        'variable': v,
        'coefficient': fitted.beta[i],
        'std_error': fitted.se[i],
        't_stat': fitted.tvalues[i],
        'p_value': fitted.pvalues[i],
    }
    for label, fitted, names in (
        ('three_way', gls_b, vars_b),
        ('two_way_reference', gls_a, vars_a),
    )
    for i, v in enumerate(names)
]

# to_csv does not create missing parent directories, so make sure the output
# folder exists instead of failing after all models have been estimated.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
pd.DataFrame(results_4c).to_csv(OUTPUT_DIR / "three_way_interactions.csv", index=False)
print(f"\nSaved three-way interaction results to {OUTPUT_DIR / 'three_way_interactions.csv'}")


# ═══════════════════════════════════════════════════════════════════════════
# PHASE 4d: PENSION MODEL ON EBA-49
# ═══════════════════════════════════════════════════════════════════════════
banner_4d = "=" * 70
print("\n" + banner_4d)
print("PHASE 4d: PENSION MODEL ON EBA-49 SUBSAMPLE")
print(banner_4d)

# Restrict the estimation sample to EBA-49 (no extended or expansion countries).
eba49 = filter_eba_sample(est, extended=False, expansion=False)
print(f"\nEBA-49 sample: {eba49['iso3'].nunique()} countries, {len(eba49):,} obs")

# Report how many observations and countries have each pension variable.
pension_vars = ['pension_spending_gdp', 'pension_coverage']
for pv in pension_vars:
    if pv not in eba49.columns:
        print(f"  {pv}: NOT in panel")
        continue
    observed = eba49[pv].notna()
    n_countries = eba49.loc[observed, 'iso3'].nunique()
    print(f"  {pv}: {observed.sum():,} obs, {n_countries} countries")

# Use pension_spending_gdp as the main pension variable
if 'pension_spending_gdp' in eba49.columns:
    # Z × pension interaction terms on the EBA-49 frame.
    pension_level = eba49['pension_spending_gdp']
    for z in ('Z_1', 'Z_2', 'Z_3'):
        eba49[f'{z}_x_pension'] = eba49[z] * pension_level

    # Model D1: baseline specification on EBA-49, kept as the comparison point.
    print("\n--- Model D1: Baseline on EBA-49 ---")
    base_vars = [*demo_vars, 'fiscal_bal_gdp', 'kaopen', 'expected_growth',
                 'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp']
    comp_d1 = eba49.dropna(subset=[*base_vars, 'ca_gdp']).copy()
    print(f"  Complete cases: {len(comp_d1):,} obs, {comp_d1['iso3'].nunique()} countries")

    y_d1 = comp_d1['ca_gdp'].values
    X_d1 = comp_d1[base_vars].values
    gls_d1 = PanelGLS()
    gls_d1.fit(y_d1, X_d1, comp_d1['iso3'].values, comp_d1['year'].values)
    gls_d1.feature_names = base_vars
    print(f"  R² = {gls_d1.r_squared:.4f}, N = {gls_d1.n_obs}")

    # Model D2: baseline regressors plus the pension_spending_gdp level.
    print("\n--- Model D2: Baseline + pension on EBA-49 ---")
    pension_base_vars = [*base_vars, 'pension_spending_gdp']
    comp_d2 = eba49.dropna(subset=[*pension_base_vars, 'ca_gdp']).copy()
    print(f"  Complete cases: {len(comp_d2):,} obs, {comp_d2['iso3'].nunique()} countries")

    y_d2 = comp_d2['ca_gdp'].values
    X_d2 = comp_d2[pension_base_vars].values
    gls_d2 = PanelGLS()
    gls_d2.fit(y_d2, X_d2, comp_d2['iso3'].values, comp_d2['year'].values)
    gls_d2.feature_names = pension_base_vars
    print(f"  R² = {gls_d2.r_squared:.4f}, N = {gls_d2.n_obs}")

    # Report the pension coefficient with conventional significance stars.
    pens_idx = pension_base_vars.index('pension_spending_gdp')
    pens_p = gls_d2.pvalues[pens_idx]
    if pens_p < 0.001:
        sig = '***'
    elif pens_p < 0.01:
        sig = '**'
    elif pens_p < 0.05:
        sig = '*'
    else:
        sig = ''
    print(f"  pension_spending_gdp: {gls_d2.beta[pens_idx]:.3f} (p={pens_p:.4f}) {sig}")

    # Model D3: Baseline + pension + Z×pension interactions on EBA-49
    print("\n--- Model D3: Baseline + pension + Z×pension interactions on EBA-49 ---")
    z_pension_vars = ['Z_1_x_pension', 'Z_2_x_pension', 'Z_3_x_pension']
    pension_int_vars = pension_base_vars + z_pension_vars
    comp_d3 = eba49.dropna(subset=pension_int_vars + ['ca_gdp']).copy()
    print(f"  Complete cases: {len(comp_d3):,} obs, {comp_d3['iso3'].nunique()} countries")

    y_d3 = comp_d3['ca_gdp'].values
    X_d3 = comp_d3[pension_int_vars].values
    gls_d3 = PanelGLS()
    gls_d3.fit(y_d3, X_d3, comp_d3['iso3'].values, comp_d3['year'].values)
    gls_d3.feature_names = pension_int_vars
    print(f"  R² = {gls_d3.r_squared:.4f}, N = {gls_d3.n_obs}")

    print("\n  All coefficients:")
    for i, v in enumerate(pension_int_vars):
        sig = '***' if gls_d3.pvalues[i] < 0.001 else '**' if gls_d3.pvalues[i] < 0.01 else '*' if gls_d3.pvalues[i] < 0.05 else ''
        print(f"  {v}: {gls_d3.beta[i]:.4f} (se={gls_d3.se[i]:.4f}, p={gls_d3.pvalues[i]:.4f}) {sig}")

    # Joint F-test of the Z×pension terms: D3 (unrestricted) against D2
    # (restricted), using the R²-based form
    #   F = ((R²_u - R²_r) / q) / ((1 - R²_u) / (n - k - 1))
    q_pens = len(z_pension_vars)  # three Z×pension terms
    k_d3 = len(pension_int_vars)
    n_d3 = gls_d3.n_obs
    df_resid_d3 = n_d3 - k_d3 - 1
    F_pension = ((gls_d3.r_squared - gls_d2.r_squared) / q_pens) / ((1 - gls_d3.r_squared) / df_resid_d3)
    # Survival function rather than 1 - cdf: keeps precision for very small p-values.
    p_pension = stats.f.sf(F_pension, q_pens, df_resid_d3)
    print(f"\n  Joint F-test (Z×pension interactions): F({q_pens},{df_resid_d3}) = {F_pension:.3f}, p = {p_pension:.4f}")

    # Model D4: D3 plus the Z×KAOPEN terms, to check whether the pension
    # interactions survive once the KAOPEN interactions are also included.
    # gls_d4 ends up None whenever the model cannot be estimated.
    print("\n--- Model D4: Full model with both pension and KAOPEN interactions ---")
    if 'Z_1_x_kaopen' not in eba49.columns:
        print("  Z×KAOPEN interactions not available")
        gls_d4 = None
    else:
        kaopen_int = [f'{z}_x_kaopen' for z in ('Z_1', 'Z_2', 'Z_3')]
        full_int_vars = [*pension_int_vars, *kaopen_int]
        comp_d4 = eba49.dropna(subset=[*full_int_vars, 'ca_gdp']).copy()
        print(f"  Complete cases: {len(comp_d4):,} obs, {comp_d4['iso3'].nunique()} countries")

        # Require more than (number of regressors + 10) complete rows.
        if len(comp_d4) <= len(full_int_vars) + 10:
            print(f"  Insufficient observations ({len(comp_d4)}) for full model")
            gls_d4 = None
        else:
            y_d4 = comp_d4['ca_gdp'].values
            X_d4 = comp_d4[full_int_vars].values
            gls_d4 = PanelGLS()
            gls_d4.fit(y_d4, X_d4, comp_d4['iso3'].values, comp_d4['year'].values)
            gls_d4.feature_names = full_int_vars
            print(f"  R² = {gls_d4.r_squared:.4f}, N = {gls_d4.n_obs}")

            print("\n  Interaction coefficients (both sets):")
            for idx, name in enumerate(full_int_vars):
                if not ('x_pension' in name or 'x_kaopen' in name):
                    continue
                p_val = gls_d4.pvalues[idx]
                if p_val < 0.001:
                    sig = '***'
                elif p_val < 0.01:
                    sig = '**'
                elif p_val < 0.05:
                    sig = '*'
                else:
                    sig = ''
                print(f"  {name}: {gls_d4.beta[idx]:.4f} (se={gls_d4.se[idx]:.4f}, p={p_val:.4f}) {sig}")

    # Model D5: the D3 regressor set re-estimated on the full 140-country
    # sample, for comparison with the EBA-49 results.
    print("\n--- Model D5: Pension model on full 140-country sample ---")
    pension_full = full_sample['pension_spending_gdp']
    for z in ('Z_1', 'Z_2', 'Z_3'):
        full_sample[f'{z}_x_pension'] = full_sample[z] * pension_full

    comp_d5 = full_sample.dropna(subset=[*pension_int_vars, 'ca_gdp']).copy()
    print(f"  Complete cases: {len(comp_d5):,} obs, {comp_d5['iso3'].nunique()} countries")

    # Require more than (number of regressors + 10) complete rows; otherwise
    # gls_d5 is None and no D5 rows are produced.
    if len(comp_d5) <= len(pension_int_vars) + 10:
        print(f"  Insufficient pension data in full sample ({len(comp_d5)} obs)")
        gls_d5 = None
    else:
        y_d5 = comp_d5['ca_gdp'].values
        X_d5 = comp_d5[pension_int_vars].values
        gls_d5 = PanelGLS()
        gls_d5.fit(y_d5, X_d5, comp_d5['iso3'].values, comp_d5['year'].values)
        gls_d5.feature_names = pension_int_vars
        print(f"  R² = {gls_d5.r_squared:.4f}, N = {gls_d5.n_obs}")

        print("\n  Pension interaction coefficients (full sample):")
        for idx, name in enumerate(pension_int_vars):
            if 'pension' not in name:
                continue
            p_val = gls_d5.pvalues[idx]
            if p_val < 0.001:
                sig = '***'
            elif p_val < 0.01:
                sig = '**'
            elif p_val < 0.05:
                sig = '*'
            else:
                sig = ''
            print(f"  {name}: {gls_d5.beta[idx]:.4f} (se={gls_d5.se[idx]:.4f}, p={p_val:.4f}) {sig}")

    # Save pension results: one row per coefficient for every model that was
    # estimated.  D4 and D5 are optional (None when they could not be fitted).
    fitted_models = [
        ('D1_baseline_eba49', gls_d1, base_vars),
        ('D2_pension_eba49', gls_d2, pension_base_vars),
        ('D3_pension_interactions_eba49', gls_d3, pension_int_vars),
    ]
    if gls_d4 is not None:
        fitted_models.append(('D4_pension_plus_kaopen_eba49', gls_d4, full_int_vars))
    if gls_d5 is not None:
        fitted_models.append(('D5_pension_interactions_full140', gls_d5, pension_int_vars))

    results_4d = [
        {
            'model': model_name,
            'variable': v,
            'coefficient': gls_obj.beta[i],
            'std_error': gls_obj.se[i],
            't_stat': gls_obj.tvalues[i],
            'p_value': gls_obj.pvalues[i],
        }
        for model_name, gls_obj, var_list in fitted_models
        for i, v in enumerate(var_list)
    ]

    # to_csv does not create missing parent directories, so make sure the
    # output folder exists instead of failing after all models have run.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(results_4d).to_csv(OUTPUT_DIR / "pension_model_tests.csv", index=False)
    print(f"\nSaved pension model results to {OUTPUT_DIR / 'pension_model_tests.csv'}")

else:
    print("  pension_spending_gdp not available in panel — skipping pension tests")


# ═══════════════════════════════════════════════════════════════════════════
# SUMMARY
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)

# Phase 4c results always exist at this point.
print(f"""
Phase 4c: Three-way interactions
  Model A (standard Z×KAOPEN): R² = {gls_a.r_squared:.4f}
  Model B (Z×KAOPEN×income):   R² = {gls_b.r_squared:.4f}
  F-test (B vs A):             F = {F_threeway:.3f}, p = {p_threeway:.4f}
""")

# The Phase 4d objects (gls_d1..gls_d4, F_pension, p_pension) are only created
# when pension_spending_gdp is in the EBA-49 frame.  Re-check that condition
# here so the summary cannot raise NameError when the pension tests were skipped.
if 'pension_spending_gdp' in eba49.columns:
    print(f"""Phase 4d: Pension model on EBA-49
  Model D1 (baseline):         R² = {gls_d1.r_squared:.4f}, N = {gls_d1.n_obs}
  Model D2 (+ pension):        R² = {gls_d2.r_squared:.4f}, N = {gls_d2.n_obs}
  Model D3 (+ Z×pension):      R² = {gls_d3.r_squared:.4f}, N = {gls_d3.n_obs}
  F-test (Z×pension):          F = {F_pension:.3f}, p = {p_pension:.4f}
""")

    if gls_d4 is not None:
        print(f"  Model D4 (+ both):           R² = {gls_d4.r_squared:.4f}, N = {gls_d4.n_obs}")
else:
    print("Phase 4d: skipped (pension_spending_gdp not available in panel)")

print("\nDone.")
